#!/bin/bash
# Slurm batch-script TEMPLATE for multi-node accelerate/DeepSpeed training.
# All {placeholder} fields are substituted by the job-submission tooling
# before this file is handed to sbatch — do not treat them as shell syntax.
#SBATCH --partition={partition}
#SBATCH --time={time_limit}
#SBATCH --nodes={num_nodes}
#SBATCH --gres=gpu:{gpus_per_node}
#SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task={cpus_per_node}
#SBATCH --account={account}
# %x = job name, %j = job id (Slurm filename patterns).
#SBATCH --output={experiments_dir}/logs/%x_%j.out
#SBATCH --job-name={job_name}
# Slack channel notifications via Slack's email-to-channel bridge.
#SBATCH --mail-type=END,TIME_LIMIT,FAIL
#SBATCH --mail-user=dcft-slurm-notifs-aaaap7wt363mcsgryaejj2o6dm@dogs-and-ml.slack.com

# CUDA
# Toolchain modules required by the training stack (CUDA 12.6 + Intel MKL).
module load CUDA/12.6.0
module load imkl/2023.2.0

# ENVIRONMENT
# Shared project root on the cluster work filesystem; the sourced env files
# define DCFT_PRIVATE, DCFT_PRIVATE_ACTIVATE_ENV, APPTAINER_ARGS, DB access, etc.
export DCFT=/rwthfs/rz/cluster/hpcwork/rwth1775/dcft
source $DCFT/dcft_private/hpc/dotenv/claix.env
source $DCFT/dcft_private/database/access.env
echo "Loading conda: $DCFT_PRIVATE_ACTIVATE_ENV"
# Intentionally unquoted: presumably holds a full command line (e.g.
# "source <path>/activate <env>") that must word-split to execute —
# TODO(review): confirm against claix.env before quoting.
$DCFT_PRIVATE_ACTIVATE_ENV

# NETWORKING
# The first hostname in the allocation is rank 0 and doubles as the
# torch/accelerate rendezvous address. Allow communication over InfiniBand cells.
MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
export MASTER_ADDR
export MASTER_PORT=20043
export NCCL_NET_GDR_LEVEL=0 # for h100
export NCCL_SOCKET_IFNAME=ib0
export NCCL_IB_TIMEOUT=60

# RESOURCES
# NOTE(review): device list is hard-coded to 4 GPUs; if {gpus_per_node} != 4
# this disagrees with NUM_GPUS_PER_NODE below — confirm before reuse.
export CUDA_VISIBLE_DEVICES="0,1,2,3"
# One OpenMP thread per process avoids CPU oversubscription under srun.
export OMP_NUM_THREADS=1
export SRUN_CPUS_PER_TASK=${SLURM_CPUS_PER_TASK}
export NUM_NODES=$SLURM_JOB_NUM_NODES
export NUM_GPUS_PER_NODE={gpus_per_node}
# Total world size across the allocation.
export NUM_GPUS=$((NUM_GPUS_PER_NODE*SLURM_NNODES))
DEEPSPEED_CONFIG_FILE={deepspeed}

# PATHS
# CONFIG is the llamafactory training YAML; OUTPUT_DIR receives logs/artifacts.
CONFIG={train_config_path_out}
OUTPUT_DIR={experiments_dir}
TMP_DIR=$OUTPUT_DIR/tmp
# Quoted so paths with spaces survive; -p creates OUTPUT_DIR implicitly.
mkdir -p -- "$OUTPUT_DIR" "$TMP_DIR"

# ACCELERATE
# Auto-generate a per-job accelerate config; removed again at end of job.
ACCELERATE_CONFIG_FILE="$TMP_DIR/${SLURM_JOB_ID}_accelerate_config.yaml.autogenerated"
# EOT is deliberately unquoted: $vars below must expand into the generated YAML.
cat << EOT > "$ACCELERATE_CONFIG_FILE"
# WARNING: do not edit this file as this is an slurm-auto-generated file
compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_multinode_launcher: standard
  deepspeed_config_file: $DEEPSPEED_CONFIG_FILE
  zero3_init_flag: true
distributed_type: DEEPSPEED
fsdp_config:
machine_rank: 0
main_process_ip: $MASTER_ADDR
main_process_port: $MASTER_PORT
main_training_function: main
num_machines: $SLURM_NNODES
num_processes: $NUM_GPUS
use_cpu: false
EOT

# MAIN TRAIN COMMAND
# \$SLURM_PROCID and \$(hostname -s) are escaped on purpose: they must be
# evaluated by each srun task on its own node, not by this batch shell.
# (The rdzv_conf value is a single comma-separated word, so it needs no quotes;
# the previous inner double quotes merely closed/reopened the outer string.)
LAUNCHER="python -u -m accelerate.commands.launch \
    --rdzv_conf rdzv_backend=c10d,rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT \
    --config_file $ACCELERATE_CONFIG_FILE \
    --main_process_ip $MASTER_ADDR \
    --main_process_port $MASTER_PORT \
    --machine_rank \$SLURM_PROCID \
    --role \$(hostname -s): --tee 3 \
    "
CMD="$LAUNCHER dcft/train/llamafactory/src/train.py $CONFIG"
# srun options mirror the #SBATCH allocation; --kill-on-bad-exit tears the
# whole step down if any rank dies, --label prefixes output with the task id.
SRUN_ARGS="
    --nodes=$NUM_NODES \
    --gres=gpu:$NUM_GPUS_PER_NODE \
    --cpus-per-task=$SLURM_CPUS_PER_TASK \
    --wait=60 \
    --kill-on-bad-exit=1 \
    --label \
    --jobid $SLURM_JOBID"

# RUN
# Fail fast if the checkout is missing rather than srun-ing from the wrong cwd.
cd "$DCFT_PRIVATE" || exit 1
# SRUN_ARGS and APPTAINER_ARGS are intentionally unquoted: each holds multiple
# whitespace-separated arguments that must word-split.
srun $SRUN_ARGS $APPTAINER_ARGS bash -c "$CMD"
# Clean up the auto-generated accelerate config (single file — no -r needed).
rm -f -- "$ACCELERATE_CONFIG_FILE"